// Package client is a standalone ACS client.
package client

import (
	"acs/comet/proto"
	"acs/pbmodel"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"runtime/debug"
	"sync"
	"sync/atomic"
	"time"

	log "github.com/cihub/seelog"
	pb "github.com/golang/protobuf/proto"
)

// MAX_TRANSACTION_ID is the largest transaction ID handed out by
// genTransactionID; IDs cycle through 1..MAX_TRANSACTION_ID.
const MAX_TRANSACTION_ID = 32765

const (
	// DEFAULT_CONN_TIMEOUT is the default timeout in milliseconds for dialing,
	// for writing a packet and for waiting on a response.
	DEFAULT_CONN_TIMEOUT = 11000

	// DefaultConnDataReadTimeout is the default data-read timeout applied to
	// each individual Read call on a connection.
	DefaultConnDataReadTimeout = time.Millisecond

	// MAX_BLOCKED_REQUESTS is the maximum number of outgoing requests that may
	// be waiting for a response; SendMessage refuses new ones beyond it.
	MAX_BLOCKED_REQUESTS = 50000

	// MAX_BLOCKED_REQUESTS_OF_REMOTE is the number of remote requests in
	// progress above which handleRemoteRequest pauses taking new ones.
	MAX_BLOCKED_REQUESTS_OF_REMOTE = 1000

	// HEARBEAT_INTERVAL is the default heartbeat interval in milliseconds,
	// used when ClientConfig.HeartBeatTtl is below 1.
	// NOTE(review): the original note said it should "not be less than"
	// DEFAULT_CONN_TIMEOUT to avoid read timeouts, yet the value (8000) is
	// below DEFAULT_CONN_TIMEOUT (11000); presumably it should stay below
	// the timeout so the link never idles past it — confirm.
	HEARBEAT_INTERVAL = 8000
)

var (
	// ErrTooManyBlockedRequests is returned by SendMessage when
	// MAX_BLOCKED_REQUESTS requests are already awaiting responses.
	ErrTooManyBlockedRequests = fmt.Errorf("Too many blocked requests over %v", MAX_BLOCKED_REQUESTS)
	// ErrEmptyHandshakeCmdString is returned by NewClient when the handshake
	// command is empty.
	ErrEmptyHandshakeCmdString = errors.New("Empty handshake cmd string!")
)

// Message is a protobuf message that can also produce a fresh message of its
// corresponding type via New; incoming requests are unmarshalled into the
// value New returns.
type Message interface {
	pb.Message
	New() pb.Message
}

type (
	// responseData is one response (or proto-error) packet handed from
	// readPacket to the goroutine waiting in getResponse.
	responseData struct {
		data         []byte
		responseType proto.PacketType
	}

	// responseContainer tracks outgoing requests that await a response: one
	// channel per transaction ID. The embedded RWMutex guards the map.
	responseContainer struct {
		sync.RWMutex
		respDataChans map[uint16]chan *responseData
	}

	// requestContainer maps transaction IDs to request-data channels; the
	// embedded RWMutex guards the map.
	// NOTE(review): apart from its initialisation in NewClient it is not used
	// by the code visible here.
	requestContainer struct {
		sync.RWMutex
		requestDataChans map[uint16]chan *[]byte
	}
)

// HandshakeMessage is the request sent after connecting and after every
// reconnect: ReqMsg is sent as command Cmd and the reply is unmarshalled
// into Resp.
type HandshakeMessage struct {
	Cmd    string
	ReqMsg pb.Message
	Resp   pb.Message
}

// add registers a new response channel with a buffer of one for transId,
// replacing any channel already registered under that ID.
func (rc *responseContainer) add(transId uint16) {
	slot := make(chan *responseData, 1)
	rc.Lock()
	rc.respDataChans[transId] = slot
	rc.Unlock()
}

// delete removes and closes the response channel registered for transId.
// It is a no-op when no channel is registered, so repeated calls are safe.
func (rc *responseContainer) delete(transId uint16) {
	rc.Lock()
	defer rc.Unlock()
	slot, found := rc.respDataChans[transId]
	if !found {
		return
	}
	delete(rc.respDataChans, transId)
	close(slot)
}

// getBlockedReqLength reports how many transactions currently have a
// registered response channel, i.e. are still awaiting a response.
func (rc *responseContainer) getBlockedReqLength() int {
	rc.RLock()
	pending := len(rc.respDataChans)
	rc.RUnlock()
	return pending
}

// createConn opens a TCP connection to addr, giving up after
// DEFAULT_CONN_TIMEOUT milliseconds.
func createConn(addr string) (net.Conn, error) {
	dialTimeout := time.Duration(DEFAULT_CONN_TIMEOUT) * time.Millisecond
	return net.DialTimeout("tcp", addr, dialTimeout)
}

type (
	// RequestHandler handles one request from the remote end. It receives the
	// decoded request message and returns the message sent back as the
	// response; handleRemoteRequest sends the result even when the handler
	// also returns an error.
	RequestHandler func(pb.Message) (pb.Message, error)

	// requestHandlerWrapper keeps a RequestHandler together with the request
	// prototype (used to decode incoming payloads) and the result prototype
	// given to RegisterRequestHandler.
	requestHandlerWrapper struct {
		handler   RequestHandler
		reqMsg    Message
		resultMsg pb.Message
	}

	// handlerContainer maps command names to their handlers; the embedded
	// RWMutex guards the map.
	handlerContainer struct {
		*sync.RWMutex
		handlers map[string]*requestHandlerWrapper
	}
)

// Client multiplexes outgoing requests (SendMessage) and incoming requests
// from the remote end (dispatched to handlers registered with
// RegisterRequestHandler) over a single TCP connection. The connection is
// rebuilt after read or write errors.
type Client struct {
	cfg  *ClientConfig
	conn net.Conn

	// transIdCounter is the last transaction ID handed out.
	// transIdCounterLock guards it and also the err field.
	transIdCounter     uint16
	transIdCounterLock *sync.RWMutex

	// responseContainer holds a response channel per outgoing transaction.
	responseContainer *responseContainer
	requestContainer  *requestContainer
	// requestChan carries parsed remote requests from readPacket to
	// handleRemoteRequest.
	requestChan     chan *proto.Request
	requestHandlers *handlerContainer

	// sendLock lets only one goroutine write to conn at a time, and readLock
	// lets only one goroutine read from it; rebuildConn holds both while it
	// swaps conn.
	sendLock, readLock *sync.RWMutex

	// err is set by setErrorAndRebuild and cleared by rebuildConn once a new
	// connection is in place.
	err error
	// inProcessingCount is the number of remote requests being handled; it
	// is only accessed atomically.
	inProcessingCount int32
	logger            log.LoggerInterface
}

// ClientConfig holds the settings NewClient uses to build a Client.
type ClientConfig struct {
	// Addr is the server address in "IP:PORT" form.
	Addr string
	// HeartBeatTtl is the heartbeat interval in milliseconds; NewClient
	// replaces values below 1 with HEARBEAT_INTERVAL.
	HeartBeatTtl int64
	// HandshakeMessage is sent after connecting and after every reconnect;
	// NewClient rejects an empty Cmd.
	HandshakeMessage HandshakeMessage
	// Loglevel is not read by the code visible here. NOTE(review): confirm
	// where it is consumed.
	Loglevel LogLevel
	// Logger receives the client's logs; seelog's log.Default is used when nil.
	Logger log.LoggerInterface
	// DataEncryptKey and DataEncryptIV are passed to proto when encrypting
	// outgoing and decrypting incoming payloads.
	DataEncryptKey, DataEncryptIV string
}

// NewClient connects to config.Addr and starts the client. It starts the
// packet reader and the remote-request dispatcher, performs the configured
// handshake, and then starts the heartbeat loop.
//
// It returns ErrEmptyHandshakeCmdString when config.HandshakeMessage.Cmd is
// empty, the dial error when the connection cannot be opened, and a
// "Handshake error: ..." error when the handshake fails.
//
// NOTE(review): on handshake failure the open connection and the two
// goroutines already started are not shut down.
func NewClient(config ClientConfig) (*Client, error) {
	if config.HandshakeMessage.Cmd == "" {
		return nil, ErrEmptyHandshakeCmdString
	}
	conn, err := createConn(config.Addr)
	if err != nil {
		return nil, err
	}
	if config.HeartBeatTtl < 1 {
		config.HeartBeatTtl = HEARBEAT_INTERVAL
	}
	logger := config.Logger
	if logger == nil {
		logger = log.Default
	}
	c := &Client{
		cfg:                &config,
		conn:               conn,
		transIdCounterLock: new(sync.RWMutex),
		responseContainer:  &responseContainer{respDataChans: map[uint16]chan *responseData{}},
		requestContainer:   &requestContainer{requestDataChans: map[uint16]chan *[]byte{}},
		requestHandlers: &handlerContainer{
			RWMutex:  new(sync.RWMutex),
			handlers: make(map[string]*requestHandlerWrapper),
		},
		requestChan: make(chan *proto.Request, 1000),
		sendLock:    new(sync.RWMutex),
		readLock:    new(sync.RWMutex),
		logger:      logger,
	}
	// The reader must already be running: the handshake below waits for a
	// response that only readPacket delivers.
	go c.readPacket()
	go c.handleRemoteRequest()
	if err := c.doHandshake(); err != nil {
		return nil, errors.New("Handshake error: " + err.Error())
	}
	go c.keepHeartbeat()
	return c, nil
}

// ParseRequestData parses the payload of a request packet received with
// transaction ID transID.
//
// rawData holds the command name, optionally followed by '\n' and the
// encrypted request body. The body is decrypted with the client's configured
// key/IV. An error is returned when the command is empty or decryption fails.
func (c *Client) ParseRequestData(transID uint16, rawData *[]byte) (req *proto.Request, err error) {
	raw := *rawData
	total := len(raw)

	// The command runs up to (not including) the first '\n', or to the end.
	cmdEnd := 0
	for cmdEnd < total && raw[cmdEnd] != '\n' {
		cmdEnd++
	}
	if cmdEnd == 0 {
		return nil, errors.New("request cmd is empty")
	}
	cmd := string(raw[:cmdEnd])

	body := []byte{}
	if cmdEnd < total {
		body = raw[cmdEnd+1:]
	}
	if body, err = proto.Decrypt(body, c.cfg.DataEncryptKey, c.cfg.DataEncryptIV); err != nil {
		return nil, err
	}
	return &proto.Request{
		Length:     uint32(total - cmdEnd),
		Tsid:       transID,
		Data:       body,
		Cmd:        cmd,
		EncryptCfg: proto.EncryptCfg{EncryptKey: c.cfg.DataEncryptKey, EncryptIV: c.cfg.DataEncryptIV},
	}, nil
}

// setErrorAndRebuild records err as the client's current error and then
// reconnects. Because the error is recorded, SendMessage and sendResponse
// fail fast in the meantime. The call blocks until a new connection has been
// established. It is meant mainly for the read loop, which observes
// connection errors continuously.
func (c *Client) setErrorAndRebuild(err error) {
	func() {
		c.transIdCounterLock.Lock()
		defer c.transIdCounterLock.Unlock()
		c.err = err
	}()
	c.rebuildConn()
}

// rebuildConn replaces the client's connection. It closes the current
// connection and dials cfg.Addr until that succeeds, starting attempts at
// least DEFAULT_CONN_TIMEOUT ms apart. The new connection is installed while
// holding both sendLock and readLock. The stored error is then cleared, and
// the configured handshake is sent again.
//
// The handshake runs on its own goroutine. rebuildConn is reached from
// readPacket (via setErrorAndRebuild), and readPacket is the only goroutine
// that delivers responses. A synchronous handshake there would wait for a
// response that cannot be read until the request times out. Every reconnect
// would then stall for DEFAULT_CONN_TIMEOUT, and the handshake would fail.
func (c *Client) rebuildConn() {
	for {
		c.logger.Info("Reconnecting...")
		// Best effort: the connection is usually already broken, and on
		// retries it has already been closed.
		c.conn.Close()
		attemptStart := time.Now()
		conn, err := createConn(c.cfg.Addr)
		if err == nil {
			c.sendLock.Lock()
			c.readLock.Lock()
			c.logger.Infof("Reconnected to %v .", c.cfg.Addr)
			c.conn = conn
			c.transIdCounterLock.Lock()
			c.err = nil
			c.transIdCounterLock.Unlock()
			c.sendLock.Unlock()
			c.readLock.Unlock()
			go func() {
				// The original code discarded this error entirely.
				if err := c.doHandshake(); err != nil {
					c.logger.Errorf("Handshake after reconnect failed: %v", err)
				}
			}()
			return
		}
		c.logger.Infof("Reconnect to %v err: %v.....try again later...", c.cfg.Addr, err)
		if wait := time.Millisecond*DEFAULT_CONN_TIMEOUT - time.Since(attemptStart); wait > 0 {
			time.Sleep(wait)
		}
	}
}
// getError returns the error recorded by setErrorAndRebuild, or nil when no
// error is recorded (initially, and again after rebuildConn reconnects).
func (c *Client) getError() error {
	c.transIdCounterLock.RLock()
	current := c.err
	c.transIdCounterLock.RUnlock()
	return current
}

// keepHeartbeat sends a proto.CmdPing request every cfg.HeartBeatTtl
// milliseconds and never returns. Failures are logged; they do not stop
// the loop.
func (c *Client) keepHeartbeat() {
	interval := time.Duration(c.cfg.HeartBeatTtl) * time.Millisecond
	for range time.Tick(interval) {
		tid, res, err := c.SendMessage(proto.CmdPing, &pbmodel.Ping{}, &pbmodel.PingResp{})
		if err != nil {
			c.logger.Errorf("Heartbeat error: %v", err)
		}
		c.logger.Debugf("Heartbeat data: tid: %v, response: %v", tid, res)
	}
}

// doHandshake sends the configured handshake request and unmarshals the
// reply into cfg.HandshakeMessage.Resp, returning only the error.
func (c *Client) doHandshake() error {
	hs := c.cfg.HandshakeMessage
	_, _, err := c.SendMessage(hs.Cmd, hs.ReqMsg, hs.Resp)
	return err
}

// RegisterRequestHandler installs hdl as the handler for remote requests
// with command cmd, replacing any handler previously registered for cmd.
//
// req is a prototype of the request type hdl expects: each incoming payload
// is unmarshalled into req.New() before hdl is called. result is stored
// alongside as the type hdl is expected to return.
func (c *Client) RegisterRequestHandler(hdl RequestHandler, cmd string, req Message, result pb.Message) {
	wrapped := &requestHandlerWrapper{handler: hdl, reqMsg: req, resultMsg: result}
	c.requestHandlers.Lock()
	c.requestHandlers.handlers[cmd] = wrapped
	c.requestHandlers.Unlock()
}

// getHanlder returns the handler registered for cmd, or nil if there is none.
func (c *Client) getHanlder(cmd string) *requestHandlerWrapper {
	c.requestHandlers.RLock()
	defer c.requestHandlers.RUnlock()
	return c.requestHandlers.handlers[cmd]
}

// SendMessage sends msg to the remote end as command cmd. It then waits up
// to DEFAULT_CONN_TIMEOUT ms for the matching response and unmarshals it
// into resp.
//
// It returns the transaction ID used, the decoded response (resp itself on
// success) and any error. It fails fast with ErrTooManyBlockedRequests when
// MAX_BLOCKED_REQUESTS requests are already pending. While a reconnect is in
// progress, it fails fast with the client's stored error. A write failure
// reconnects (via setErrorAndRebuild) on the calling goroutine before
// returning.
func (c *Client) SendMessage(cmd string, msg pb.Message, resp pb.Message) (transId uint16, re pb.Message, err error) {
	if c.responseContainer.getBlockedReqLength() >= MAX_BLOCKED_REQUESTS {
		err = ErrTooManyBlockedRequests
		return
	}
	if err = c.getError(); err != nil {
		return
	}
	var data []byte
	if data, err = pb.Marshal(msg); err != nil {
		err = errors.New("msg marshal error: " + err.Error())
		return
	}

	transId = c.genTransactionID()
	c.responseContainer.add(transId)
	// Always release the response slot, including when no response arrives.
	// Otherwise every timed-out request counts towards MAX_BLOCKED_REQUESTS
	// forever. delete is a no-op if the slot is already gone.
	defer c.responseContainer.delete(transId)

	c.logger.Debugf("Sending cmd[%v]: transactionID[%v]...", cmd, transId)
	data = proto.NewRequest(transId, cmd, data, c.cfg.DataEncryptKey, c.cfg.DataEncryptIV).Encode()
	c.sendLock.Lock()
	for sent := 0; sent < len(data); {
		var n int
		if err = c.conn.SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(DEFAULT_CONN_TIMEOUT))); err == nil {
			n, err = c.conn.Write(data[sent:])
		}
		if err != nil {
			// Free the slot now; the reconnect below may take a long time.
			c.responseContainer.delete(transId)
			c.sendLock.Unlock()
			c.setErrorAndRebuild(err)
			return
		}
		sent += n
	}
	c.sendLock.Unlock()

	c.logger.Debugf("Waiting response for cmd[%v]: transactionID[%v]...", cmd, transId)
	re, err = c.getResponse(transId, resp)
	if err == nil {
		c.logger.Debugf("Got response for cmd[%v]: transactionID[%v]...", cmd, transId)
	} else {
		c.logger.Warnf("Failed to get response for cmd[%v]: transactionID[%v]...Err: %v", cmd, transId, err)
	}
	return transId, re, err
}

// getResponse waits up to DEFAULT_CONN_TIMEOUT ms for the response to
// transaction transID. It decrypts the response and unmarshals it into
// msgType, which is returned on success.
//
// A PacketTypeProtoError response is reported as an error that includes the
// decoded ProtoErr. The transaction's response slot is released on every
// path, including on timeout, so it cannot leak.
func (c *Client) getResponse(transID uint16, msgType pb.Message) (msg pb.Message, err error) {
	du := time.Millisecond * DEFAULT_CONN_TIMEOUT
	c.responseContainer.RLock()
	respChan, ok := c.responseContainer.respDataChans[transID]
	c.responseContainer.RUnlock()
	if !ok {
		// Receiving from the nil channel would just block until the timeout.
		return nil, fmt.Errorf("no pending response slot for transaction: %v", transID)
	}
	defer c.responseContainer.delete(transID)

	timer := time.NewTimer(du)
	defer timer.Stop()
	var rd *responseData
	select {
	case rd, ok = <-respChan:
		if !ok {
			// The slot was closed by delete; rd would be nil.
			return nil, fmt.Errorf("response channel closed for transaction: %v", transID)
		}
	case <-timer.C:
		return nil, fmt.Errorf("Timedout(%v) for getting response for transaction: %v", du, transID)
	}

	data, err := proto.Decrypt(rd.data, c.cfg.DataEncryptKey, c.cfg.DataEncryptIV)
	if err != nil {
		return nil, fmt.Errorf("response data decrypt failed: %v", err)
	}
	if rd.responseType == proto.PacketTypeProtoError {
		protoErr := new(pbmodel.ProtoErr)
		if err = pb.Unmarshal(data, protoErr); err == nil {
			return nil, fmt.Errorf("Got proto error for transaction[%v]: %v", transID, protoErr)
		}
	} else {
		err = pb.Unmarshal(data, msgType)
	}
	if err != nil {
		return nil, fmt.Errorf("transaction(%v): response msg unmarshal error: %v -- DATA: %s", transID, err, data)
	}
	return msgType, nil
}

// sendResponse marshals msg and writes it to the remote end as the response
// to transaction transID. While the client has a stored error (a reconnect
// is pending), it returns that error without writing. Write errors are
// returned to the caller; they do not trigger a reconnect here.
func (c *Client) sendResponse(transID uint16, msg pb.Message) error {
	if err := c.getError(); err != nil {
		return err
	}
	payload, err := pb.Marshal(msg)
	if err != nil {
		return err
	}
	frame := proto.NewResponse(transID, payload, c.cfg.DataEncryptKey, c.cfg.DataEncryptIV).Encode()

	c.sendLock.Lock()
	defer c.sendLock.Unlock()
	c.conn.SetWriteDeadline(time.Now().Add(time.Duration(DEFAULT_CONN_TIMEOUT) * time.Millisecond))
	_, err = c.conn.Write(frame)
	return err
}

// readPacket is the connection's read loop and never returns.
//
// Each packet is read in four parts: a proto.HeadLength-byte header whose
// first four bytes are the big-endian payload length, a proto.VerLength-byte
// version, and a proto.TsidTypeLen-byte big-endian field, followed by the
// payload. In the third field, the bits above the lowest two are the
// transaction ID and the lowest two bits are the packet type.
//
// Responses and proto errors are handed to the getResponse call waiting on
// that transaction. Requests are parsed and queued on requestChan for
// handleRemoteRequest. After any read error the client records the error
// and reconnects before reading again.
func (c *Client) readPacket() {
	var readErr error
	// Header buffers are fully consumed before the next packet, so they are
	// reused; the payload buffer is handed off and must be fresh each time.
	tagLengthB := make([]byte, proto.HeadLength)
	verB := make([]byte, proto.VerLength)
	transIDTypeB := make([]byte, proto.TsidTypeLen)
	for {
		if readErr != nil {
			// Record the real read error (previously an always-nil variable
			// was passed) so senders fail fast during the reconnect.
			c.setErrorAndRebuild(readErr)
			readErr = nil
		}
		c.readLock.Lock()
		if _, readErr = io.ReadFull(c.conn, tagLengthB); readErr != nil {
			c.readLock.Unlock()
			continue
		}
		tagLength := binary.BigEndian.Uint32(tagLengthB)
		// TODO: the protocol version is currently only logged.
		if _, readErr = io.ReadFull(c.conn, verB); readErr != nil {
			c.readLock.Unlock()
			continue
		}
		c.logger.Debugf("Got server protocol version %v", uint8(verB[0]))
		if _, readErr = io.ReadFull(c.conn, transIDTypeB); readErr != nil {
			c.readLock.Unlock()
			continue
		}
		transIDType := binary.BigEndian.Uint16(transIDTypeB)
		// NOTE(review): tagLength comes from the peer and is not bounded.
		data := make([]byte, tagLength)
		if _, readErr = io.ReadFull(c.conn, data); readErr != nil {
			c.readLock.Unlock()
			continue
		}
		c.readLock.Unlock()

		transID := transIDType >> 2
		transType := proto.PacketType(transIDType & 3)
		switch transType {
		case proto.PacketTypeResponse, proto.PacketTypeProtoError:
			c.responseContainer.RLock()
			respChan, ok := c.responseContainer.respDataChans[transID]
			if !ok {
				c.logger.Errorf("Response transaction id(%v) not known! data: %v", transID, data)
			} else {
				select {
				case respChan <- &responseData{data: data, responseType: transType}:
				default:
					// The slot (buffer of one) already holds an unread
					// response. Blocking here while holding RLock would stall
					// add/delete and with them every sender.
					c.logger.Errorf("Duplicate response for transaction id(%v) dropped, data: %v", transID, data)
				}
			}
			c.responseContainer.RUnlock()
		case proto.PacketTypeRequest:
			req, err := c.ParseRequestData(transID, &data)
			if err != nil {
				c.logger.Errorf("Failed to parse request data: %v", err)
				// req is nil here; queueing it would panic in
				// handleRemoteRequest outside any recover.
				continue
			}
			c.requestChan <- req
		default:
			c.logger.Errorf("Invalid transaction type[%v].", transType)
		}
	}
}

// genTransactionID returns the next transaction ID, cycling through
// 1..MAX_TRANSACTION_ID.
//
// NOTE(review): readPacket recovers the ID from the wire as a uint16 shifted
// right by 2, i.e. at most 14 bits (16383), which is below
// MAX_TRANSACTION_ID. Confirm how proto.NewRequest encodes larger IDs.
func (c *Client) genTransactionID() uint16 {
	c.transIdCounterLock.Lock()
	defer c.transIdCounterLock.Unlock()
	next := c.transIdCounter%MAX_TRANSACTION_ID + 1
	c.transIdCounter = next
	return next
}

// handleRemoteRequest takes requests from requestChan forever and handles
// each one on its own goroutine. When more than
// MAX_BLOCKED_REQUESTS_OF_REMOTE requests are in progress, it waits
// (checkAndWaitForRequestProcessing) before taking the next one.
//
// Each payload is unmarshalled into a fresh instance of the handler's
// request type, and the handler is invoked. Its result is then sent back
// with sendResponse, even when the handler also returned an error.
// Requests without a registered handler are logged and get no response.
// Handler panics are recovered and logged.
func (c *Client) handleRemoteRequest() {
	for {
		if atomic.LoadInt32(&c.inProcessingCount) > MAX_BLOCKED_REQUESTS_OF_REMOTE {
			c.checkAndWaitForRequestProcessing()
		}
		req := <-c.requestChan
		if req == nil {
			// A nil request would panic on req.Cmd below, outside any
			// recover, and take down the process.
			c.logger.Warnf("Dropping nil request from request queue")
			continue
		}
		hdlWrap := c.getHanlder(req.Cmd)
		if hdlWrap == nil {
			c.logger.Warnf("No handler for cmd[%v], data: %s", req.Cmd, req.Data)
			continue
		}
		// Count the request before its goroutine starts, so that the throttle
		// check above sees it immediately. When the increment ran inside the
		// goroutine, many more goroutines than the limit could be started
		// before any of them had been counted.
		atomic.AddInt32(&c.inProcessingCount, 1)
		go func(hdlW *requestHandlerWrapper, req *proto.Request) {
			defer func() {
				if r := recover(); r != nil {
					c.logger.Errorf("Request handler fatal: %v. %s", r, debug.Stack())
				}
				atomic.AddInt32(&c.inProcessingCount, -1)
			}()
			reqMsg := hdlW.reqMsg.New()
			if err := pb.Unmarshal(req.Data, reqMsg); err != nil {
				c.logger.Errorf("Request data parse error: %v", err)
				return
			}
			c.logger.Debugf("Handling acs request [%v], transactionID [%v].", req.Cmd, req.Tsid)
			result, err := hdlW.handler(reqMsg)
			if err != nil {
				c.logger.Warnf("Request [%v] trasactionID[%v] handle error: %v", req.Cmd, req.Tsid, err)
			} else {
				c.logger.Debugf("Request [%v] trasactionID[%v] handle success, result: %+v", req.Cmd, req.Tsid, result)
			}
			if err := c.sendResponse(req.Tsid, result); err != nil {
				c.logger.Errorf("Reqeust [%v]  transactionID [%v] response send error: %v", req.Cmd, req.Tsid, err)
			}
		}(hdlWrap, req)
	}
}

// checkAndWaitForRequestProcessing blocks, polling every 30ms, until fewer
// than MAX_BLOCKED_REQUESTS_OF_REMOTE remote requests are being processed.
func (c *Client) checkAndWaitForRequestProcessing() {
	ticker := time.NewTicker(30 * time.Millisecond)
	// Stop the ticker on return. Before Go 1.23 an unstopped Ticker is never
	// reclaimed, so every throttling episode used to leak one.
	defer ticker.Stop()
	for range ticker.C {
		if atomic.LoadInt32(&c.inProcessingCount) < MAX_BLOCKED_REQUESTS_OF_REMOTE {
			return
		}
	}
}

// ReadFullWithTimeout reads exactly len(buf) bytes from conn into buf.
//
// Before each underlying Read call, the read deadline is moved to
// now+timeout. The timeout therefore bounds how long any single Read may
// block, not the whole call. It returns nil once buf is full. Otherwise it
// returns the first error from SetReadDeadline or Read (including a timeout
// error), in which case buf may be only partially filled.
func ReadFullWithTimeout(conn net.Conn, buf []byte, timeout time.Duration) (err error) {
	var n int
	for filled := 0; filled < len(buf); filled += n {
		if err = conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
			return err
		}
		// Read into the unfilled tail. Reading into buf from offset 0 again
		// would overwrite the bytes already received.
		if n, err = conn.Read(buf[filled:]); err != nil {
			return err
		}
	}
	return nil
}
